# -*- coding: utf-8 -*-
import numpy as np
import os
import time
import cv2
from skimage import io
import sklearn.utils


def load_local_train_val(config=None):
    """Load all training images together with their class labels and driver ids.

    Every class sub-directory of the KaggleDDD ``train`` folder is scanned.
    The class label is the last character of the sub-directory name, parsed
    as an int. The driver id of each file is looked up in ``get_drivers()``.

    :param config: dict of preprocessing options. The flags ``'crop_right'``,
        ``'crop_center'`` and ``'to_gray'`` are enabled only when equal to the
        string ``'True'``. ``'resize'`` is converted to a tuple and passed to
        ``cv2.resize`` as the target size. A dict is required despite the
        ``None`` default.
    :return: ``(images, labels, drivers_id, unique_drivers)``:
        - ``images``: float32 array transposed to (N, C, H, W).
        - ``labels``: int32 label array.
        - ``drivers_id``: per-image list of driver ids.
        - ``unique_drivers``: sorted list of distinct driver ids.
    """
    driver_of = get_drivers()
    root = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'
    images = []
    targets = []
    owners = []  # driver id of every loaded image, in load order
    for class_dir in os.listdir(root):
        started = time.time()
        class_path = os.path.join(root, class_dir)
        class_label = int(class_dir[-1])
        loaded = 0
        for file_name in os.listdir(class_path):
            owners.append(driver_of[file_name])
            pixels = io.imread(os.path.join(class_path, file_name))
            pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
            if config['crop_right'] == 'True':
                pixels = pixels[80:400, 320:]
            if config['crop_center'] == 'True':
                pixels = pixels[:, 80:560]
            pixels = cv2.resize(pixels, tuple(config['resize']),
                                interpolation=cv2.INTER_AREA)
            if config['to_gray'] == 'True':
                # NOTE(review): the array is BGR at this point, while the
                # RGBA2GRAY code weights channel 0 as red. load_test uses the
                # same code, so train/test preprocessing stays consistent.
                # Confirm this is intended.
                pixels = cv2.cvtColor(pixels, cv2.COLOR_RGBA2GRAY)
            targets.append(class_label)
            images.append(pixels)
            loaded += 1
        print(loaded)
        print('time: {0}'.format(time.time() - started))
    image_array = np.asarray(images, dtype=np.float32)
    if config['to_gray'] == 'True':
        # Gray images have no channel axis; add it so the transpose works.
        image_array = np.expand_dims(image_array, 3)
    image_array = np.transpose(image_array, (0, 3, 1, 2))
    label_array = np.asarray(targets, dtype=np.int32)
    return image_array, label_array, owners, sorted(set(owners))


def load_local_train_val_multiscale(config=None):
    """Load training images and emit five fixed 224x224 crops of each one.

    Processing steps for each image:
    1. Read it and convert RGB->BGR.
    2. Optionally center-crop it (``config['crop_center'] == 'True'``).
    3. Resize it to ``config['resize']``.
    4. When ``config['to_gray'] == 'True'``, convert it to a single channel.
    5. Take five crops at (row, col) offsets (0, 0), (32, 0), (16, 16),
       (0, 32) and (32, 32). Each crop is stored with the image's class label
       and driver id.

    ``config['resize']`` should be at least 256x256 so that every crop is a
    full 224x224 window. When ``config['debug'] == 'True'``, only the first
    100 images of each class are used.

    :param config: dict of preprocessing options (required despite the
        ``None`` default).
    :return: ``(images, labels, drivers_id, unique_drivers)``:
        - ``images``: float32 array in (N, C, 224, 224) order.
        - ``labels``: int32 array.
        - ``drivers_id``: per-crop list of driver ids.
        - ``unique_drivers``: sorted list of distinct driver ids.
    """
    drivers_all = get_drivers()
    train_path = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'
    labels = []
    img_arrs = []
    drivers_id = []  # driver of every crop, aligned with labels / img_arrs
    to_gray = config['to_gray'] == 'True'
    for c in os.listdir(train_path):
        begin = time.time()
        flag = 0
        train_path_c = os.path.join(train_path, c)
        label = int(c[-1])
        for img in os.listdir(train_path_c):
            img_path = os.path.join(train_path_c, img)
            img_arr = io.imread(img_path)
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
            if config['crop_center'] == 'True':
                img_arr = img_arr[:, 80:560]
            img_arr = cv2.resize(img_arr, tuple(config['resize']), interpolation=cv2.INTER_AREA)
            if to_gray:
                # The post-processing below adds a channel axis when to_gray
                # is set, which only fits 2-D (gray) crops. Without this
                # conversion the stacked array is 5-D and np.transpose fails.
                # Same conversion code as the other loaders, so the
                # preprocessing stays consistent.
                img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2GRAY)
            for row, col in [[0, 0], [32, 0], [16, 16], [0, 32], [32, 32]]:
                drivers_id.append(drivers_all[img])
                labels.append(label)
                img_arrs.append(img_arr[row:row + 224, col:col + 224])
            flag += 1
            if config['debug'] == 'True' and flag >= 100:
                break
        print(flag)
        end = time.time()
        print('time: {0}'.format(end - begin))
    unique_drivers = sorted(set(drivers_id))
    lable_narrays = np.asarray(labels, dtype=np.int32)
    image_narrays = np.asarray(img_arrs, dtype=np.float32)
    if to_gray:
        image_narrays = np.expand_dims(image_narrays, 3)
    image_narrays = np.transpose(image_narrays, (0, 3, 1, 2))
    return image_narrays, lable_narrays, drivers_id, unique_drivers


def load_test(config):
    """Load every image of the KaggleDDD ``test`` folder.

    Applies the same preprocessing flags as the training loaders
    (``crop_right``, ``crop_center``, ``resize``, ``to_gray``). Progress is
    printed every 2500 files. When ``config['debug'] == 'True'``, loading
    stops at that first checkpoint; the 2500th file itself is not loaded.

    :param config: dict of preprocessing options.
    :return: ``(images, img_names)``:
        - ``images``: float32 array in (N, C, H, W) order.
        - ``img_names``: file names in the same order as ``images``.
    """
    source_dir = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/test'
    frames = []
    names = []
    seen = 0
    started = time.time()
    for file_name in os.listdir(source_dir):
        seen += 1
        if not seen % 2500:
            print('{0} images loaded, time: {1}'.format(seen, time.time() - started))
            if config['debug'] == 'True':
                break
        names.append(file_name)
        frame = cv2.cvtColor(io.imread(source_dir + '/' + file_name),
                             cv2.COLOR_RGB2BGR)
        if config['crop_right'] == 'True':
            frame = frame[80:400, 320:]
        if config['crop_center'] == 'True':
            frame = frame[:, 80:560]
        frame = cv2.resize(frame, tuple(config['resize']),
                           interpolation=cv2.INTER_AREA)
        if config['to_gray'] == 'True':
            # NOTE(review): frame is BGR here, but RGBA2GRAY weights channel 0
            # as red. This matches the training loaders; confirm it is intended.
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2GRAY)
        frames.append(frame)
    batch = np.asarray(frames, dtype=np.float32)
    if config['to_gray'] == 'True':
        batch = np.expand_dims(batch, 3)
    return np.transpose(batch, (0, 3, 1, 2)), names


def new_select_drivers(image_narrays, lable_narrays, drivers_id, drivers_list):
    """Keep only the samples whose driver id occurs in ``drivers_list``.

    :param image_narrays: indexable collection of images, aligned with
        ``drivers_id``.
    :param lable_narrays: indexable collection of labels, aligned with
        ``drivers_id``.
    :param drivers_id: driver id of every sample.
    :param drivers_list: driver ids to keep (tested with ``in``).
    :return: ``(data, target, new_drivers_id, index)``:
        - ``data``: selected images as float32.
        - ``target``: selected labels as float32.
        - ``new_drivers_id``: list of the selected driver ids.
        - ``index``: int64 array of the selected positions in the input.
    """
    index = [i for i, driver in enumerate(drivers_id) if driver in drivers_list]
    data = np.asarray([image_narrays[i] for i in index], dtype=np.float32)
    target = np.asarray([lable_narrays[i] for i in index], dtype=np.float32)
    new_drivers_id = [drivers_id[i] for i in index]
    # int64 rather than uint8: uint8 positions wrap around once the input
    # holds more than 256 samples.
    return data, target, new_drivers_id, np.asarray(index, dtype=np.int64)


def get_drivers(path='/media/dell/delldisk/dell/wxm/Data/KaggleDDD/driver_imgs_list.csv'):
    """Map image file name -> driver id from the driver/image list CSV.

    Each non-blank line is split on commas. Field 2 (the image name) becomes
    the key and field 0 (the driver id) the value. Any header line is stored
    like a data row. Blank lines are skipped.

    :param path: location of the CSV; defaults to the original hard-coded path.
    :return: dict ``{image_name: driver_id}``.
    :raises IndexError: if a non-blank line has fewer than three fields.
    """
    drivers = {}
    print('Read drivers data')
    # `with` closes the file even when a malformed line raises.
    with open(path, 'r') as handle:
        for line in handle:
            line = line.strip()
            if not line:
                # e.g. a trailing empty line; it previously caused IndexError
                continue
            fields = line.split(',')
            drivers[fields[2]] = fields[0]
    return drivers


def select_drivers(image_narrays, lable_narrays, drivers_id, drivers_list):
    """Keep only the samples whose driver id occurs in ``drivers_list``.

    :param image_narrays: indexable collection of images, aligned with
        ``drivers_id``.
    :param lable_narrays: indexable collection of labels, aligned with
        ``drivers_id``.
    :param drivers_id: driver id of every sample.
    :param drivers_list: driver ids to keep (tested with ``in``).
    :return: ``(data, target, index)``:
        - ``data``: selected images as float32.
        - ``target``: selected labels as float32.
        - ``index``: int64 array of the selected positions in the input.
    """
    index = [i for i, driver in enumerate(drivers_id) if driver in drivers_list]
    data = np.asarray([image_narrays[i] for i in index], dtype=np.float32)
    target = np.asarray([lable_narrays[i] for i in index], dtype=np.float32)
    # int64 rather than uint8: uint8 positions wrap around once the input
    # holds more than 256 samples.
    return data, target, np.asarray(index, dtype=np.int64)


def resample_by_sortedID(data, target, new_drivers_id):
    """Reorder samples in place so they are grouped by ascending driver id.

    The first ``len(data)`` entries of ``data``, ``target`` and
    ``new_drivers_id`` are permuted together. The sort is stable, so samples
    of the same driver keep their relative order. The same (mutated) objects
    are returned.

    :param data: list or numpy array of samples.
    :param target: list or numpy array of labels aligned with ``data``.
    :param new_drivers_id: driver id of each sample; used as the sort key.
    :return: ``(data, target, new_drivers_id)`` after reordering.
    """
    order = sorted(range(len(data)), key=lambda i: new_drivers_id[i])
    for seq in (data, target, new_drivers_id):
        _permute_prefix_in_place(seq, order)
    return data, target, new_drivers_id


def _permute_prefix_in_place(seq, order):
    """Rewrite ``seq[:len(order)]`` so position j holds the old ``seq[order[j]]``."""
    n = len(order)
    if isinstance(seq, np.ndarray):
        # Fancy indexing returns a copy. Assigning rows one by one from
        # `seq[i]` views would read rows that were already overwritten.
        seq[:n] = seq[np.asarray(order, dtype=np.intp)]
    else:
        seq[:n] = [seq[i] for i in order]


def load_local_train_val_DEBUG(config=None, num_samples=100):
    """Debug loader that keeps at most ``num_samples`` images per class.

    Preprocessing matches ``load_local_train_val``, except that gray
    conversion uses ``cv2.COLOR_RGB2GRAY``. Prints each class directory path,
    the number of images kept, and the time spent per class.

    :param config: dict of preprocessing options. The flags ``'crop_right'``,
        ``'crop_center'`` and ``'to_gray'`` are compared to ``'True'``;
        ``'resize'`` is the target size. Required despite the ``None`` default.
    :param num_samples: maximum number of images loaded from each class
        directory.
    :return: ``(images, labels, drivers_id, unique_drivers)``:
        - ``images``: float32 array in (N, C, H, W) order.
        - ``labels``: int32 label array.
        - ``drivers_id``: per-image list of driver ids.
        - ``unique_drivers``: sorted list of distinct driver ids.
    """
    drivers_all = get_drivers()
    train_path = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'
    labels = []
    img_arrs = []
    drivers_id = []  # drivers in train fold
    for c in os.listdir(train_path):
        begin = time.time()
        flag = 0
        train_path_c = os.path.join(train_path, c)
        print(train_path_c)
        label = int(c[-1])
        for img in os.listdir(train_path_c):
            # Test the cap before loading anything so exactly num_samples
            # images are kept per class (not num_samples + 1).
            if flag >= num_samples:
                break
            drivers_id.append(drivers_all[img])
            img_path = os.path.join(train_path_c, img)
            img_arr = io.imread(img_path)
            img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2BGR)
            if config['crop_right'] == 'True':
                img_arr = img_arr[80:400, 320:]
            if config['crop_center'] == 'True':
                img_arr = img_arr[:, 80:560]
            img_arr = cv2.resize(img_arr, tuple(config['resize']), interpolation=cv2.INTER_AREA)
            if config['to_gray'] == 'True':
                img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGB2GRAY)
            labels.append(label)
            img_arrs.append(img_arr)
            flag += 1
        print(flag)
        end = time.time()
        print('time: {0}'.format(end - begin))
    unique_drivers = sorted(set(drivers_id))
    lable_narrays = np.asarray(labels, dtype=np.int32)
    image_narrays = np.asarray(img_arrs, dtype=np.float32)
    if config['to_gray'] == 'True':
        # Gray images have no channel axis; add it so the transpose works.
        image_narrays = np.expand_dims(image_narrays, 3)
    image_narrays = np.transpose(image_narrays, (0, 3, 1, 2))
    return image_narrays, lable_narrays, drivers_id, unique_drivers


def load_train_val_by_class_a(config, split_ratio=0.7, Debug=True):
    """Split every class directory, in listing order, into train and validation.

    Within a class, the first ``int(n * split_ratio)`` files go to train and
    the rest to validation. ``n`` is the number of files in the directory.
    When ``Debug == True`` (the default!), ``n`` is forced to 200 and at most
    200 files per class are read. Images are read with ``cv2.imread``; no
    shuffling is done.

    :param config: dict with ``'crop_center'`` and ``'to_gray'`` flags
        (compared to ``'True'``) and the ``'resize'`` target size.
    :param split_ratio: fraction of each class assigned to train.
    :param Debug: enables the 200-images-per-class debug cap.
    :return: ``((train_images, train_labels), (val_images, val_labels))``;
        images are float32 in (N, C, H, W) order and labels are int32 1-D
        arrays.
    """
    root = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'
    train_x, train_y = [], []
    val_x, val_y = [], []
    for class_dir in os.listdir(root):
        started = time.time()
        class_path = root + '/' + class_dir
        total = len(os.listdir(class_path))
        if Debug == True:
            total = 200
        n_train = int(total * split_ratio)
        class_label = int(class_dir[-1])
        seen = 0
        for file_name in os.listdir(class_path):
            seen += 1
            picture = cv2.imread(class_path + '/' + file_name)
            if config['crop_center'] == 'True':
                picture = picture[:, 80:560]
            picture = cv2.resize(picture, tuple(config['resize']),
                                 interpolation=cv2.INTER_AREA)
            if config['to_gray'] == 'True':
                # NOTE(review): imread yields BGR, while RGBA2GRAY weights
                # channel 0 as red. Kept consistent with the other loaders.
                picture = cv2.cvtColor(picture, cv2.COLOR_RGBA2GRAY)
            if seen <= n_train:
                bucket_x, bucket_y = train_x, train_y
            else:
                bucket_x, bucket_y = val_x, val_y
            bucket_x.append(picture)
            bucket_y.append(class_label)
            if Debug == True and seen >= 200:
                break
        print(seen)
        print('time: {0}'.format(time.time() - started))
    train_labels = np.asarray(train_y, dtype=np.int32)
    val_labels = np.asarray(val_y, dtype=np.int32)
    train_images = np.asarray(train_x, dtype=np.float32)
    val_images = np.asarray(val_x, dtype=np.float32)
    if config['to_gray'] == 'True':
        train_images = np.expand_dims(train_images, 3)
        val_images = np.expand_dims(val_images, 3)
    train_images = np.transpose(train_images, (0, 3, 1, 2))
    val_images = np.transpose(val_images, (0, 3, 1, 2))
    return (train_images, train_labels), (val_images, val_labels)


def load_train_val_by_class_b(config, split_ratio=0.7, Debug=True):
    """Shuffle each class, then split it into train and validation.

    For every class sub-directory, the images are shuffled with
    ``sklearn.utils.shuffle(random_state=config['random_state'])``. The first
    ``int(n * split_ratio)`` shuffled images go to train and the rest to
    validation. ``n`` is the number of files in the directory. When
    ``Debug is True`` (the default!), ``n`` is 200 and at most 200 images per
    class are read.

    :param config: dict with ``'crop_center'`` and ``'to_gray'`` flags
        (compared to ``'True'``), the ``'resize'`` target size and
        ``'random_state'`` for the shuffle.
    :param split_ratio: fraction of each class assigned to train.
    :param Debug: enables the 200-images-per-class debug cap.
    :return: ``((train_images, train_labels), (val_images, val_labels))``;
        images are float32 in (N, C, H, W) order and labels are int32 column
        vectors of shape (N, 1).
    :raises IOError: if ``cv2.imread`` cannot read an image file.
    """
    train_path = '/media/dell/delldisk/dell/wxm/Data/KaggleDDD/train'
    # Per-class pieces, concatenated once after the loop instead of
    # re-stacking the growing arrays for every class.
    train_img_chunks = []
    train_label_chunks = []
    val_img_chunks = []
    val_label_chunks = []
    for c in os.listdir(train_path):
        begin = time.time()
        train_path_c = os.path.join(train_path, c)
        file_names = os.listdir(train_path_c)
        # In Debug mode the split size assumes the 200-image cap below.
        num_samples_in_c = 200 if Debug is True else len(file_names)
        num_samples_train = int(num_samples_in_c * split_ratio)
        label = int(c[-1])
        labels = []
        img_arrs = []
        flag = 0
        for img in file_names:
            # Check the cap before loading so Debug keeps exactly 200 images,
            # matching num_samples_in_c (previously only 199 were kept).
            if Debug is True and flag >= 200:
                break
            flag += 1
            img_path = os.path.join(train_path_c, img)
            img_arr = cv2.imread(img_path)
            if img_arr is None:
                # cv2.imread reports failure by returning None, not by raising.
                raise IOError('could not read image: {0}'.format(img_path))
            if config['crop_center'] == 'True':
                img_arr = img_arr[:, 80:560]
            img_arr = cv2.resize(img_arr, tuple(config['resize']), interpolation=cv2.INTER_AREA)
            if config['to_gray'] == 'True':
                # NOTE(review): imread yields BGR, while RGBA2GRAY weights
                # channel 0 as red. Kept consistent with the other loaders.
                img_arr = cv2.cvtColor(img_arr, cv2.COLOR_RGBA2GRAY)
            labels.append(label)
            img_arrs.append(img_arr)
        labels = np.asarray(labels, dtype=np.int32).reshape(-1, 1)
        img_arrs = np.asarray(img_arrs, dtype=np.float32)
        img_arrs, labels = sklearn.utils.shuffle(img_arrs, labels, random_state=config['random_state'])
        train_img_chunks.append(img_arrs[:num_samples_train])
        train_label_chunks.append(labels[:num_samples_train])
        val_img_chunks.append(img_arrs[num_samples_train:])
        val_label_chunks.append(labels[num_samples_train:])
        print(flag)
        print('label : {0} | time: {1}'.format(label, time.time() - begin))
    # All chunks have ndim >= 2, so concatenating on axis 0 matches the old
    # np.row_stack (a deprecated alias of np.vstack).
    train_image_narrays = np.concatenate(train_img_chunks, axis=0)
    train_lable_narrays = np.concatenate(train_label_chunks, axis=0)
    val_image_narrays = np.concatenate(val_img_chunks, axis=0)
    val_label_narrays = np.concatenate(val_label_chunks, axis=0)
    if config['to_gray'] == 'True':
        train_image_narrays = np.expand_dims(train_image_narrays, 3)
        val_image_narrays = np.expand_dims(val_image_narrays, 3)
    train_image_narrays = np.transpose(train_image_narrays, (0, 3, 1, 2))
    val_image_narrays = np.transpose(val_image_narrays, (0, 3, 1, 2))
    return (train_image_narrays, train_lable_narrays), (val_image_narrays, val_label_narrays)
